import hashlib
import time
from typing import Dict, List, Tuple

from tensorlake.applications.interface import InternalError
from tensorlake.applications.metadata import serialize_metadata
from tensorlake.applications.user_data_serializer import serializer_by_name

from ..blob_store.blob_store import BLOBStore
from ..logger import FunctionExecutorLogger
from ..proto.function_executor_pb2 import (
    BLOB,
    SerializedObjectEncoding,
    SerializedObjectInsideBLOB,
    SerializedObjectManifest,
)
from .value import SerializedValue


def serialized_values_to_serialized_objects(
    serialized_values: Dict[str, SerializedValue],
) -> Tuple[Dict[str, SerializedObjectInsideBLOB], List[bytes]]:
    """Builds a SerializedObjectInsideBLOB for every supplied SerializedValue.

    Every value takes one contiguous region: its serialized metadata immediately
    followed by its data. Regions are placed back-to-back, starting at offset 0,
    in the iteration order of `serialized_values`.

    Returns (SOs keyed by SerializedValue.metadata.id, byte chunks in the order
    matching the computed offsets).
    Raises InternalError if a value has no metadata.
    """
    objects_by_value_id: Dict[str, SerializedObjectInsideBLOB] = {}
    chunks: List[bytes] = []
    next_offset: int = 0

    for value in serialized_values.values():
        metadata = value.metadata
        if metadata is None:
            raise InternalError(
                "SerializedValue.metadata cannot be None in output values generated by SDK."
            )

        metadata_bytes = serialize_metadata(metadata)

        # Values without a named serializer are stored as raw bytes.
        serializer_name = metadata.serializer_name
        if serializer_name is None:
            encoding: SerializedObjectEncoding = (
                SerializedObjectEncoding.SERIALIZED_OBJECT_ENCODING_RAW
            )
        else:
            encoding = serializer_by_name(serializer_name).serialized_object_encoding

        metadata_size: int = len(metadata_bytes)
        object_size: int = metadata_size + len(value.data)
        manifest = SerializedObjectManifest(
            encoding=encoding,
            encoding_version=0,
            size=object_size,
            metadata_size=metadata_size,
            # The hash covers the metadata bytes followed by the data bytes.
            sha256_hash=_sha256_hexdigest(metadata_bytes, value.data),
            content_type=value.content_type,
        )
        objects_by_value_id[metadata.id] = SerializedObjectInsideBLOB(
            manifest=manifest,
            offset=next_offset,
        )

        # Chunk order must mirror the offsets: metadata first, then data.
        chunks.extend((metadata_bytes, value.data))
        next_offset += object_size

    return objects_by_value_id, chunks


def upload_serialized_objects_to_blob(
    serialized_objects: Dict[str, SerializedObjectInsideBLOB],
    blob_data: List[bytes],
    destination_blob: BLOB,
    blob_store: BLOBStore,
    logger: FunctionExecutorLogger,
) -> BLOB:
    """Writes `blob_data` into `destination_blob` and returns the BLOB produced by the upload.

    `serialized_objects` is only used to report the number of objects in the logs.
    Raises InternalError if the data is larger than the destination blob.
    """
    objects_count: int = len(serialized_objects)
    total_size: int = sum(map(len, blob_data))

    started_at = time.monotonic()
    logger.info(
        "uploading serialized objects to blob",
        objects_count=objects_count,
        total_size=total_size,
    )

    result: BLOB = _put_data_to_blob(
        blob_data=blob_data,
        destination=destination_blob,
        blob_store=blob_store,
        logger=logger,
    )

    elapsed_sec = time.monotonic() - started_at
    logger.info(
        "uploaded serialized objects to blob",
        objects_count=objects_count,
        total_size=total_size,
        duration_sec=f"{elapsed_sec:.3f}",
    )
    return result


def upload_request_error(
    utf8_message: bytes,
    destination_blob: BLOB,
    blob_store: BLOBStore,
    logger: FunctionExecutorLogger,
) -> Tuple[SerializedObjectInsideBLOB, BLOB]:
    """Uploads the request error message to the destination blob.

    The message is the only content written to the blob, so the returned
    serialized object sits at offset 0 and carries no metadata.

    Returns (SO describing the uploaded message, BLOB returned by the upload).
    Raises InternalError if the message is larger than the destination blob.
    """
    message_size: int = len(utf8_message)
    start_time = time.monotonic()
    logger.info(
        "uploading request error output",
        size=message_size,
    )
    # Keyword arguments keep this call in line with the other upload entry point.
    uploaded_blob: BLOB = _put_data_to_blob(
        blob_data=[utf8_message],
        destination=destination_blob,
        blob_store=blob_store,
        logger=logger,
    )
    # Same "request error" wording as the start message so both log lines
    # of one upload can be found with a single search.
    logger.info(
        "request error output uploaded",
        size=message_size,
        duration_sec=f"{time.monotonic() - start_time:.3f}",
    )

    return (
        SerializedObjectInsideBLOB(
            manifest=SerializedObjectManifest(
                encoding=SerializedObjectEncoding.SERIALIZED_OBJECT_ENCODING_UTF8_TEXT,
                encoding_version=0,
                size=message_size,
                metadata_size=0,
                # No metadata precedes the message, hence the empty metadata part.
                sha256_hash=_sha256_hexdigest(b"", utf8_message),
            ),
            offset=0,
        ),
        uploaded_blob,
    )


def _put_data_to_blob(
    blob_data: List[bytes],
    destination: BLOB,
    blob_store: BLOBStore,
    logger: FunctionExecutorLogger,
) -> BLOB:
    """Writes `blob_data` into `destination` through the blob store.

    Returns whatever `blob_store.put` returns for the upload.
    Raises InternalError if the data is larger than the sum of the destination chunk sizes.
    """
    required_size: int = sum(map(len, blob_data))
    capacity: int = sum(chunk.size for chunk in destination.chunks)

    # Happy path first: the data fits, so hand it over to the blob store.
    if required_size <= capacity:
        return blob_store.put(
            blob=destination,
            data=blob_data,
            logger=logger,
        )

    raise InternalError(
        f"Function output size {required_size} exceeds the total size of BLOB {capacity}."
    )


def _sha256_hexdigest(metadata: bytes, data: bytes) -> str:
    """Returns the hex SHA-256 digest of `metadata` followed by `data`."""
    # Seed the hasher with the metadata and feed the data incrementally,
    # which avoids building a concatenated copy of both buffers.
    digest = hashlib.sha256(metadata)
    digest.update(data)
    return digest.hexdigest()
